# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import os
import random
import struct
from typing import Union

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info

dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}


def code(dtype):
    for k in dtypes:
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)
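
# `code(dtype)` and `dtypes[...]` are inverses of each other: `code(np.uint16)`
# is 8 and `dtypes[8]` is `np.uint16`, so the single dtype byte written into
# each chunk file header can be mapped back to a NumPy dtype when reading.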


HDR_MAGIC = b'LITPKDS'
HDR_SIZE = 24  # bytes
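# On-disk chunk file layout (see PackedDatasetBuilder._write_chunk and
# PackedDatasetIterator._read_header):
#
#   bytes  0-6   HDR_MAGIC (b'LITPKDS')
#   bytes  7-14  version, little-endian uint64 (currently 1)
#   byte   15    dtype code (key into `dtypes` above)
#   bytes 16-23  chunk_size, little-endian uint64 (elements per chunk)
#   bytes 24-    chunk_size * itemsize bytes of token data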

# Local stand-in for the StrPath type alias (str, bytes, or os.PathLike)
StrPath = Union[str, bytes, os.PathLike]


class PackedDataset(IterableDataset):
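    """Iterable dataset that yields fixed-size token blocks from packed chunk files.

    The chunk files are sharded across processes and dataloader workers: with
    `num_processes` processes and `num_workers` dataloader workers each, there
    are `num_processes * num_workers` shards, the file list is truncated to a
    multiple of the shard count, and every shard reads a strided subset of the
    files. For example, with 10 files, 2 processes and 2 workers per process,
    only the first 8 files are used; shard 0 (rank 0, worker 0) reads files
    0 and 4, shard 1 reads files 1 and 5, and so on.
    """
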
    def __init__(
        self,
        filenames,
        n_chunks,
        block_size,
        seed=12345,
        shuffle=True,
        wrap=False,
        num_processes=1,
        process_rank=0,
    ):
        self._filenames = filenames
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap
        self._num_processes = num_processes
        self._process_rank = process_rank

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # Truncate the file list so that every shard sees the same number of files
        max_num_files = len(self._filenames) // num_shards * num_shards
        filenames = self._filenames[shard_id:max_num_files:num_shards]

        return PackedDatasetIterator(
            filenames=filenames,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )


class PackedDatasetBuilder(object):
    """
    Packs large token arrays into fixed-size binary chunk files.

    Large arrays are split into smaller 'chunks', each stored in its own binary file.
    Every file begins with a header recording the format version, data type and chunk
    size, so each chunk file is self-describing. This makes it practical to process or
    move datasets that are too large to handle as a single file.

    Packing data works like this:

    1) A buffer of `chunk_size` elements is created and prefilled with pad tokens.
    2) When `add_array` is called with a tokenized sequence, the tokens are copied into
       the buffer; whenever the buffer fills up, it is written to disk as a chunk file
       and reset, and any remaining tokens continue in the next chunk.
    3) `write_remainder` flushes the final, partially filled buffer.

    Parameters:
        outdir (str): The output directory where the chunk files will be stored.
        prefix (str): The prefix to use for naming the chunk files.
        chunk_size (int): The maximum number of elements each chunk file should contain.
        pad_token (int): Incomplete chunks will be filled with pad_token.
        dtype (str or numpy.dtype, optional): The data type of the array elements. If 'auto', the dtype is determined based on `vocab_size`.
            Defaults to 'auto'.
        vocab_size (int, optional): The maximum size of the vocabulary. Required if dtype is 'auto'.
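
    Example (illustrative sketch; the directory, prefix, sizes and the
    `tokenized_documents` iterable are hypothetical):

        builder = PackedDatasetBuilder(
            outdir='data/packed', prefix='train', chunk_size=2048,
            pad_token=0, dtype='auto', vocab_size=32000,
        )
        for ids in tokenized_documents:  # each `ids` is a 1-D array of token ids
            builder.add_array(np.array(ids, dtype=builder.dtype))
        builder.write_remainder()  # flush the final, partially filled chunk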
    """

    def __init__(
        self,
        outdir: StrPath,
        prefix: str,
        chunk_size: int,
        pad_token: int,
        dtype='auto',
        vocab_size=None,
    ):
        if dtype == 'auto':
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            if vocab_size < 65500:
                # Token ids fit into an unsigned 16-bit integer (max 65535)
                self._dtype = np.uint16
            else:
                self._dtype = np.int32
        else:
            self._dtype = dtype
        self._counter = 0
        self._chunk_size = chunk_size
        self._outdir = outdir
        self._prefix = prefix
        self._pad_token = pad_token

        # Buffer for the current chunk, prefilled with pad tokens; it is filled as
        # token arrays are added and flushed to disk whenever it becomes full
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._pad_token)

        self._idx = 0
        self._version = 1
        self._filenames = []

    def _write_chunk(self):
        filename = f'{self._prefix}_{self._counter:010d}.bin'
        filename = os.path.join(self._outdir, filename)

        with open(filename, 'wb') as f:
            # Header identifying the file format: magic marker, version, dtype code, chunk size
            f.write(HDR_MAGIC)
            f.write(struct.pack('<Q', self._version))
            f.write(struct.pack('<B', code(self._dtype)))
            f.write(struct.pack('<Q', self._chunk_size))
            # ...followed by the chunk data itself
            f.write(self._arr.tobytes(order='C'))

        self._filenames.append(filename)
        self._counter += 1

        # The chunk has been written, so reset the buffer for the next one
        self._arr.fill(self._pad_token)
        self._idx = 0

    @property
    def dtype(self):
        return self._dtype

    @property
    def filenames(self):
        return self._filenames.copy()

    def add_array(self, arr: np.ndarray):
        # If the incoming array does not fit into the current chunk, fill the chunk
        # to the brim, flush it to disk, and continue with the remainder (which may
        # itself span several chunks).
        while self._idx + arr.shape[0] > self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        # Whatever is left fits into the current chunk
        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_remainder(self):
        # Flush the current, partially filled (pad-padded) chunk to disk
        self._write_chunk()


class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._block_idxs = None

        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        #       (or text file) with the sequence of all files to be
        #       fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()

    def _read_header(self, path):
        with open(path, 'rb') as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            (version,) = struct.unpack('<Q', f.read(8))
            assert version == 1, f'Unsupported file version: {version}'
            (dtype_code,) = struct.unpack('<B', f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack('<Q', f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        for mmap in self._mmaps:
            # np.memmap has no public close(); release the underlying mmap object
            mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        # Not enough files remain in this shard to load another n_chunks
        if self._n_chunks > len(self._filenames[self._file_idx :]):
            if not self._wrap:
                print(
                    'No more chunks, stopping. (Note: If this happens when preparing data, see https://github.com/Lightning-AI/lit-llama/issues/425)'
                )
                raise StopIteration

            # Wrap around and start reading from the first file again
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode='r', order='C', offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = (
            self._rng.permutation(n_all_blocks)
            if self._shuffle
            else range(n_all_blocks)
        )

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        # Map the (possibly shuffled) global block index to a chunk and an
        # element offset within that chunk
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id  # byte offset into the chunk
        arr = np.frombuffer(
            buffer, dtype=self._dtype, count=self._block_size, offset=offset
        )
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))


class CombinedDataset(IterableDataset):
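    """Mixture of iterable datasets, sampled per item with the given weights.

    Each call to `__next__` on the resulting iterator picks one of the member
    dataset iterators at random (equal weights by default) and returns its next
    item.

    Example (illustrative sketch; the member datasets are hypothetical):

        combined = CombinedDataset([wiki_dataset, code_dataset], seed=42, weights=[0.7, 0.3])
        sample = next(iter(combined))
    """
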
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)


class CombinedDatasetIterator:
    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)
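

if __name__ == '__main__':
    # Minimal end-to-end sketch (illustrative only; all sizes, prefixes and
    # token values below are arbitrary): pack two short token streams into
    # chunk files in a temporary directory, then read them back in blocks of
    # 4 tokens.
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        builder = PackedDatasetBuilder(
            outdir=tmpdir,
            prefix='demo',
            chunk_size=16,
            pad_token=0,
            dtype='auto',
            vocab_size=100,
        )
        builder.add_array(np.arange(1, 11, dtype=builder.dtype))
        builder.add_array(np.arange(11, 21, dtype=builder.dtype))
        builder.write_remainder()

        dataset = PackedDataset(
            builder.filenames, n_chunks=1, block_size=4, shuffle=False, wrap=False
        )
        for block in dataset:
            print(block)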
